import datetime
from typing import Optional, Sequence

from PySide6.QtCore import QRunnable, QObject, Signal, Slot
from vnpy.event import EventEngine, Event, EVENT_TIMER
from vnpy.trader.constant import Direction
from vnpy.trader.engine import MainEngine
from vnpy.trader.event import EVENT_ACCOUNT, EVENT_LOG, EVENT_POSITION, EVENT_ORDER, EVENT_TRADE
from vnpy.trader.object import LogData, PositionData, AccountData, TickData, OrderData, TradeData
from vnpy_ctp import CtpGateway
from vnpy_scripttrader import ScriptEngine

from mquant.trader.object import AccountModelData, ExtendedLogData, ExtendedPositionData, RiskManagerModelData, \
    ExtendedAccountData, ExtendedOrderData, ExtendedTradeData, ExtendedScriptEngine
from mquant.trader.utils import remove_keys_from_dict
from mquant_sqlite.sqlite_database import RiskManagerModelTable


class WorkerSignals(QObject):
    """Holder for the Qt signals a ``Worker`` uses to publish vnpy events.

    ``QRunnable`` does not derive from ``QObject`` and therefore cannot declare
    signals itself, so each worker keeps an instance of this class. Every
    signal carries exactly one vnpy ``Event`` whose ``data`` is the
    account-tagged ("Extended*") object built by the worker.
    """

    # Account snapshot (balance, frozen, risk threshold).
    account = Signal(Event)
    # Log message re-tagged with the account id.
    log = Signal(Event)
    # Position update.
    position = Signal(Event)
    # Trade (fill) report.
    trade = Signal(Event)
    # Order status update.
    order = Signal(Event)


class Worker(QRunnable):
    """Run one CTP account's trading session and enforce its max-loss rule.

    ``run()`` builds a vnpy ``MainEngine``/``EventEngine`` pair, registers
    handlers that re-emit account/log/position/order/trade events through
    ``self.signals`` (tagged with this worker's account id), and connects the
    CTP gateway. On every timer event the account balance is compared with
    ``init_balance * (1 - max_loss)``; on a breach all active orders are
    cancelled and every open position is closed at the tick's limit price.
    """

    def __init__(self, account, db):
        """
        :param account: account configuration (credentials, server addresses,
            ``init_balance``) used to connect the gateway.
        :param db: database helper; must provide
            ``load_riskManager_data_by_account`` and ``quit``.
        """
        super().__init__()
        self.db = db
        # Created in run(); None until then so quit() can tell it never started.
        self.engine: Optional[ExtendedScriptEngine] = None
        self.account: AccountModelData = account
        self.signals = WorkerSignals()

        self.risk: RiskManagerModelData = self.db.load_riskManager_data_by_account(self.account.account)

    def run(self):
        """Create the engines, register event handlers, then connect the CTP gateway."""
        event_engine = EventEngine()
        main_engine = MainEngine(event_engine)
        main_engine.add_gateway(CtpGateway)
        self.engine = ExtendedScriptEngine(main_engine, event_engine)

        # Register handlers BEFORE connecting: vnpy's MainEngine starts the
        # event engine in its constructor, so events raised while connecting
        # would otherwise be dispatched with no handler and dropped.
        event_engine.register(EVENT_ACCOUNT, self.process_account_event)
        event_engine.register(EVENT_LOG, self.process_log_event)
        event_engine.register(EVENT_POSITION, self.process_position_event)
        event_engine.register(EVENT_ORDER, self.process_order_event)
        event_engine.register(EVENT_TRADE, self.process_trade_event)
        event_engine.register(EVENT_TIMER, self.process_timer_event)

        self.engine.connect_gateway({
            "用户名": self.account.account,
            "密码": self.account.password,
            "经纪商代码": self.account.broker,
            "交易服务器": self.account.trade_server,
            "行情服务器": self.account.quotation_server,
            "产品名称": self.account.product_name,
            "授权编码": self.account.authorization_code
        }, "CTP")

    def quit(self):
        """Close the main engine; does nothing if run() has not created it yet."""
        if self.engine is None:
            return
        self.engine.main_engine.close()

    def _threshold_balance(self) -> float:
        """Return the balance below which forced liquidation is triggered."""
        # float() first so the computation works whatever numeric type the DB returns for max_loss.
        return self.account.init_balance * (1 - float(self.risk.max_loss))

    @Slot()
    def process_account_event(self, event):
        """Re-emit an account update tagged with this account and its risk threshold."""
        data: AccountData = event.data
        extend_data = ExtendedAccountData(
            accountid=self.account.account,
            init_balance=round(self.account.init_balance, 1),
            balance=round(data.balance, 1),
            frozen=round(data.frozen, 1),
            thre_balance=round(self._threshold_balance(), 1),
            gateway_name=data.gateway_name,
        )
        self.signals.account.emit(Event(type=event.type, data=extend_data))

    @Slot()
    def process_log_event(self, event):
        """Re-emit a log message tagged with this account."""
        data: LogData = event.data
        extend_data = ExtendedLogData(
            accountid=self.account.account,
            msg=data.msg,
            gateway_name=data.gateway_name,
            level=data.level,
        )
        self.signals.log.emit(Event(type=event.type, data=extend_data))

    @Slot()
    def process_position_event(self, event):
        """Re-emit a position update tagged with this account."""
        data: PositionData = event.data
        # Drop the derived vt_* fields before splatting into the extended dataclass.
        data_dict = remove_keys_from_dict(data.__dict__.copy(), ['vt_symbol', 'vt_positionid'])
        extend_data = ExtendedPositionData(accountid=self.account.account, **data_dict)
        self.signals.position.emit(Event(type=event.type, data=extend_data))

    @Slot()
    def process_order_event(self, event):
        """Re-emit an order update tagged with this account; reference is forced to "Mtrader"."""
        data: OrderData = event.data
        data_dict = remove_keys_from_dict(data.__dict__.copy(), ['vt_symbol', 'vt_orderid', 'reference'])
        extend_data = ExtendedOrderData(accountid=self.account.account, reference="Mtrader", **data_dict)
        self.signals.order.emit(Event(type=event.type, data=extend_data))

    @Slot()
    def process_trade_event(self, event):
        """Re-emit a trade report tagged with this account."""
        data: TradeData = event.data
        data_dict = remove_keys_from_dict(data.__dict__.copy(), ['vt_symbol', 'vt_orderid', 'vt_tradeid'])
        extend_data = ExtendedTradeData(accountid=self.account.account, **data_dict)
        self.signals.trade.emit(Event(type=event.type, data=extend_data))

    def process_timer_event(self, event):
        """Check the max-loss rule on each timer tick and liquidate on breach.

        Skipped entirely once ``self.risk.sleep`` is set, and until positions,
        account data and ticks for every held symbol are all available.
        """
        if self.risk.sleep:
            return

        positions: Sequence[PositionData] = self.engine.get_all_positions()
        accounts: Sequence[AccountData] = self.engine.get_all_accounts()
        if not positions or not accounts:
            return
        account = accounts[0]

        # Subscribe so limit prices become available for every held symbol.
        vt_symbols = [position.vt_symbol for position in positions]
        self.engine.subscribe(vt_symbols)

        ticks: Sequence[TickData] = self.engine.get_ticks(vt_symbols=vt_symbols)
        if any(tick is None for tick in ticks):
            return

        if account.balance < self._threshold_balance():
            self._force_close(positions, ticks)

    def _force_close(self, positions, ticks):
        """Mark the account as triggered, cancel active orders and close all open positions.

        Long positions are sold at ``limit_down`` and short positions are
        covered at ``limit_up`` of the symbol's latest tick.
        """
        self.engine.write_log(f"{datetime.datetime.now()}风控触发：{self.account.username}用户仓位强平！")

        # Persist the sleep flag first so the liquidation is only triggered once.
        self.risk.sleep = 1
        RiskManagerModelTable.update(sleep=self.risk.sleep).where(
            RiskManagerModelTable.account == self.account.account).execute()
        self.db.quit()

        # ScriptEngine.cancel_order looks orders up by vt_orderid; the bare
        # orderid is not found and the cancel would be silently skipped.
        for order in self.engine.get_all_active_orders():
            self.engine.cancel_order(order.vt_orderid)

        latest_ticks = {tick.vt_symbol: tick for tick in ticks}
        for position in positions:
            if position.volume == 0:
                continue

            tick = latest_ticks.get(position.vt_symbol)
            if tick is None:
                # Never send an order with price=None.
                self.engine.write_log(f"{datetime.datetime.now()}风控强平：{position.vt_symbol}缺少行情，跳过")
                continue

            if position.direction == Direction.LONG:
                self.engine.sell(vt_symbol=position.vt_symbol,
                                 price=tick.limit_down,
                                 volume=position.volume)
            elif position.direction == Direction.SHORT:
                self.engine.cover(vt_symbol=position.vt_symbol,
                                  price=tick.limit_up,
                                  volume=position.volume)
